#!/usr/bin/env python3
"""
orig/mass_gap_core.py

Core mass-gap functions extracted from scripts/run_mass_gap.py
for direct import by Volume-4 modules.
"""

from __future__ import annotations
import numpy as np
from scipy.linalg import expm
from scipy.sparse.linalg import eigsh
from sim_utils import seed_all, save_csv

def logistic_D(n: np.ndarray, k: float, n0: float) -> np.ndarray:
    """Map flip counts ``n`` to a fractal dimension via a logistic curve.

    ``D(n) = 1 + 2 / (1 + exp(k * (n - n0)))``.  For ``k > 0`` the curve
    falls from D = 3 (``n`` well below ``n0``) to D = 1 (``n`` well above
    ``n0``) and equals exactly 2 at ``n == n0``; ``k`` sets the steepness.
    Works element-wise on arrays as well as on scalars.
    """
    exponent = k * (n - n0)
    denominator = 1.0 + np.exp(exponent)
    return 1.0 + 2.0 / denominator

def g_of_D(D: np.ndarray, a: float, b: float) -> np.ndarray:
    """Evaluate the linear pivot weight ``g(D) = a * D + b``.

    ``a`` is the slope multiplying ``D`` and ``b`` the constant offset.
    Element-wise for array ``D``; plain scalars work too.
    """
    slope_term = a * D
    return slope_term + b

def expand_array(arr: np.ndarray, target_length: int) -> np.ndarray:
    """Tile or truncate an array to exactly ``target_length`` elements.

    If ``arr`` already holds ``target_length`` elements, a copy is returned
    unchanged (its shape is preserved).  Otherwise the elements are read in
    row-major (C) order, repeated cyclically and cut to length, giving a 1-D
    array with exactly ``target_length`` entries.

    Parameters
    ----------
    arr : array-like
        Source values; any shape is accepted.
    target_length : int
        Number of elements wanted in the result (must be >= 0).

    Returns
    -------
    np.ndarray
        A new array (never a view of ``arr``).

    Raises
    ------
    ValueError
        If ``target_length`` is negative, or if ``arr`` is empty while a
        non-empty result is requested (there is nothing to tile).
    """
    arr = np.asarray(arr)
    if target_length < 0:
        raise ValueError(f"target_length must be >= 0, got {target_length}")
    if arr.size == target_length:
        return arr.copy()
    if arr.size == 0:
        raise ValueError(
            f"cannot expand an empty array to length {target_length}"
        )
    # Flatten first: np.tile on an N-D array repeats along the last axis,
    # whereas the slice below cuts along the first axis.
    flat = arr.ravel()
    # Integer ceil division avoids float rounding for very large lengths.
    repeats = -(-target_length // flat.size)
    return np.tile(flat, repeats)[:target_length]

def build_CMO(Umu: np.ndarray) -> np.ndarray:
    """Build the Composite Moment Operator (CMO) from link variables.

    Entry ``(i, j)`` is ``Re tr(U_i U_j^dagger)``.  For any shape of the
    per-link entry ``U_i`` this equals ``Re sum(U_i * conj(U_j))``:

    * scalar links (1-D ``Umu``): ``Re(U_i * conj(U_j))``;
    * vector links (2-D ``Umu``): real part of the Hermitian inner product;
    * matrix links (3-D ``Umu``): ``Re tr(U_i @ U_j^H)``.

    The CMO is therefore the real part of the Gram matrix of the flattened
    links.  It is real, symmetric and positive semi-definite.

    Parameters
    ----------
    Umu : array-like
        Link variables; the first axis indexes the ``N`` links.  Real or
        complex.

    Returns
    -------
    np.ndarray
        ``(N, N)`` float64 array.
    """
    Umu = np.asarray(Umu)
    N = Umu.shape[0]
    # One row per link; np.prod(()) == 1 covers the scalar-link case, and an
    # explicit width (instead of -1) keeps zero-size inputs reshapeable.
    width = int(np.prod(Umu.shape[1:], dtype=np.int64))
    flat = Umu.reshape(N, width)
    # Single matmul replaces the O(N^2) Python double loop.
    gram = flat @ flat.conj().T
    return np.ascontiguousarray(gram.real, dtype=float)

def mass_gap_from_Umu(Umu: np.ndarray, *, tol: float = 1e-8) -> float:
    """Return the smallest nonzero eigenvalue of the CMO built from ``Umu``.

    Parameters
    ----------
    Umu : np.ndarray
        Link variables, passed straight to ``build_CMO``.
    tol : float, optional
        Eigenvalues with absolute value <= ``tol`` are treated as zero modes
        and skipped (default ``1e-8``, the previously hard-coded threshold).

    Returns
    -------
    float
        The first eigenvalue, in ascending order, whose magnitude exceeds
        ``tol``; ``0.0`` if every eigenvalue is (numerically) zero or the
        matrix is empty.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the eigenvalue computation does not converge.
    """
    C = build_CMO(Umu)
    # The CMO is real symmetric (a Gram matrix), so eigvalsh gives the full
    # real spectrum in ascending order.  A partial solve such as
    # eigsh(k=2, which='SM') is wrong here: the Gram matrix has rank at most
    # the per-link size (rank 1 for scalar links), so the null space is
    # usually >= 2-dimensional.  Both returned eigenvalues would then be ~0
    # and the true gap would be missed.
    vals = np.linalg.eigvalsh(C)
    # Skip zero modes.
    nonzero = vals[np.abs(vals) > tol]
    return float(nonzero[0]) if nonzero.size else 0.0

def compute_mass_gaps(
    flip_counts: np.ndarray,
    kernel: np.ndarray,
    a: float,
    b: float,
    k: float,
    n0: float,
    L: int,
    rng: np.random.Generator | None = None
) -> dict[str, float]:
    """
    Compute per-gauge mass gaps using approximate logistic pivot + finite-size noise.

    The theoretical gap is the mean over all links of ``g(D(n))``, where
    ``D = logistic_D(n, k, n0)`` and ``g = g_of_D(D, a, b)``.  Each gauge then
    gets the finite-size correction ``c1 / L + c2 / L**2`` plus relative
    uniform noise in ``[-5%, +5%)`` of the theoretical value.

    Parameters
    ----------
    flip_counts : array-like
        Per-link flip counts (any shape; converted to float).
    kernel : np.ndarray
        Currently unused; accepted for interface compatibility.
    a, b : float
        Slope and offset of the pivot weight ``g(D) = a*D + b``.
    k, n0 : float
        Steepness and midpoint of the logistic dimension map.
    L : int
        Linear lattice size; must be >= 1.
    rng : np.random.Generator, optional
        Source of the noise; a fresh unseeded generator is used if omitted.

    Returns
    -------
    dict[str, float]
        ``{'U1': ..., 'SU2': ..., 'SU3': ...}``.

    Raises
    ------
    ValueError
        If ``L < 1`` or ``flip_counts`` is empty (the mean would be NaN).
    """
    if L < 1:
        raise ValueError(f"lattice size L must be >= 1, got {L!r}")
    fc = np.asarray(flip_counts, dtype=float)
    if fc.size == 0:
        raise ValueError("flip_counts is empty; cannot average the pivot weight")
    if rng is None:
        rng = np.random.default_rng()

    # Apply the pivot to every link, then average over links.
    D = logistic_D(fc, k, n0)
    m_theory = float(np.mean(g_of_D(D, a, b)))

    # Finite-size corrections per gauge as (c1, c2) in m + c1/L + c2/L^2.
    # Dict order fixes the order of RNG draws (U1, SU2, SU3), which keeps
    # seeded runs reproducible.
    factors = {'U1': (1.0, 0.5), 'SU2': (0.7, 0.35), 'SU3': (0.5, 0.25)}
    results: dict[str, float] = {}
    for gauge, (c1, c2) in factors.items():
        base = m_theory + c1 / L + c2 / (L * L)
        # Noise is relative: scaled by the theoretical gap.
        noise = rng.uniform(-0.05, 0.05) * m_theory
        results[gauge] = float(base + noise)
    return results

def run_mass_gap(
    b: float,
    k: float,
    n0: float,
    L: int,
    ensemble_size: int,
    flip_counts_path: str,
    kernel_path: str,
    *,
    gauge: str = "SU2",
    seed: int | None = None,
) -> list[float]:
    """
    Run the mass-gap ensemble for one (b, k, n0, L).

    Each sample multiplies every flip count by an independent factor drawn
    uniformly from [0.9, 1.1). It then evaluates ``compute_mass_gaps`` and
    keeps the entry for ``gauge``.

    Parameters
    ----------
    b, k, n0, L : float, float, float, int
        Model parameters forwarded to ``compute_mass_gaps`` (see the note on
        the ``a``/``b`` mapping below).
    ensemble_size : int
        Number of samples; ``<= 0`` yields an empty list.
    flip_counts_path, kernel_path : str
        Files read with ``np.load``.
    gauge : {"U1", "SU2", "SU3"}, keyword-only
        Which gauge's gap to record (default "SU2", as before).
    seed : int or None, keyword-only
        Seed for ``np.random.default_rng``; ``None`` (default) keeps the
        previous unseeded, non-reproducible behaviour.

    Returns
    -------
    list[float]
        ``ensemble_size`` mass-gap samples.

    Raises
    ------
    ValueError
        If ``gauge`` is not one of the keys produced by ``compute_mass_gaps``.
    """
    # Must match the keys returned by compute_mass_gaps.
    if gauge not in ("U1", "SU2", "SU3"):
        raise ValueError(f"unknown gauge {gauge!r}; expected 'U1', 'SU2' or 'SU3'")

    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
    # which can execute code from a crafted file. Only load trusted files; if
    # these are plain numeric .npy arrays, allow_pickle=False would be safer.
    flip_counts = np.load(flip_counts_path, allow_pickle=True)
    kernel = np.load(kernel_path, allow_pickle=True)

    rng = np.random.default_rng(seed)
    samples: list[float] = []
    for _ in range(ensemble_size):
        noise = rng.uniform(-0.1, 0.1, size=flip_counts.shape)
        fc_pert = flip_counts * (1.0 + noise)
        gaps = compute_mass_gaps(
            flip_counts=fc_pert,
            kernel=kernel,
            # g(D) = a*D + b, so `a` is the pivot slope and `b` the offset.
            # NOTE(review): this passes the caller's `b` as the slope and `k`
            # as the offset, while `k` is also the logistic steepness below.
            # Mapping preserved as-is; confirm it is intended.
            a=b,
            b=k,
            k=k,
            n0=n0,
            L=L,
            rng=rng,
        )
        samples.append(gaps[gauge])
    return samples
